/*
Copyright 2023, The University of Texas at Austin
All rights reserved.

THIS FILE IS PART OF THE DRAM FAULT ERROR SIMULATION FRAMEWORK

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:


1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.


2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.


3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.


THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
*/

/**
 * @file: Fault.hh
 * @author: Jungrae Kim <dale40@gmail.com> and Seong-Lyong Gong
 * Fault declaration
 */

#ifndef __FAULT_HH__
#define __FAULT_HH__

#include <algorithm>
#include <random>
#include <string>
#include <memory>
#include <cstring>
#include "Config.hh"
#include "FaultDomain.hh"
#include "common.hh"
#include "limits.h"
#include "message.hh"

// Globals that are only declared in this header; their definitions live in
// another translation unit.
extern int BANKSPERBEAT;  // passed as the banks-per-beat constructor argument by several fault classes below
extern unsigned long long SBANK_MASK_DEGRADE;  // declared here but not referenced in this header
extern int column_address_bits;  // used below as a shift amount when building address masks
extern int row_address_bits;  // used below for mask shifts and for the row/group bit budget
extern int module;  // declared here but not referenced in this header
extern char DRAMTYPE[16];  // compared below against "HBM3", "DDR5", "LPDDR5" and "LPDDR5_SPLIT"
extern int numofBanks;  // used by BankPatternFault to build a multi-bank mask

//------------------------------------------------------------------------------
// Shared <random> engine used by the genRandomError() implementations below.
extern std::default_random_engine randomGenerator;

/**@addtogroup Fault_Management
 * @{
 * @class Fault
 * @brief This class defines characteristics of each fault type
 * @details Each fault type is defined by multiple attributes. For example, a
 * permanent single bit fault has isTransient=false, numDQ=1, isSingleBeat=true,
 * isMultiColumn=false, isMultiRow=false, and isChannel=false. When a fault is
 * chosen for injection, random bit errors are generated according to these
 * attributes.
 */
class Fault {
 public:
  /**
   * @brief Name-only constructor (defined out of line). The inherent fault
   *        classes in this file use it and then assign their own fields.
   */
  Fault(std::string _name);

  /**
   * @brief Full constructor (defined out of line).
   * @note The parameter names mirror the data members declared at the bottom
   *       of this class; how each argument is stored is decided by the
   *       out-of-line definition, which is not visible in this header.
   */
  Fault(FaultDomain *_fd, std::string _name, ADDR _mask, bool _isInherent,
        bool _isTransient, int _numDQ, bool _isSingleBeat, bool _isMultiRow,
        bool _isMultiColumn, bool _isChannel,
        unsigned long long _affectedBlkCount, int _banksperBeat);
  virtual ~Fault();

 public:
  // ---- Plain accessors -----------------------------------------------------
  ADDR getAddr() const { return addr; }
  ADDR getMask() const { return mask; }
  std::string getName() const { return name; }

  /// Returns effective_mask, or mask while effective_mask is still zero.
  ADDR getEffectiveMask() const {
    return (effective_mask == 0) ? mask : effective_mask;
  }

  bool getIsInherent() const { return isInherent; }
  bool getIsTransient() const { return isTransient; }
  int getNumDQ() const { return numDQ; }
  bool getIsSingleDQ() const { return numDQ == 1; }
  bool getIsSingleBeat() const { return isSingleBeat; }
  bool getIsMultiColumn() const { return isMultiColumn; }
  bool getIsMultiRow() const { return isMultiRow; }
  bool getIsChannel() const { return isChannel; }
  unsigned long long getAffectedBlkCount() const { return affectedBlkCount; }
  int getChipID() const { return chipPos; }
  int getPinID() const { return pinPos[0]; }
  int getPinID1() const { return pinPos[1]; }

  /// Number of inherent faults; asserts when called on a non-inherent fault
  /// (and returns 0 if assertions are compiled out).
  int getNumInherents() const {
    if (isInherent) {
      return numInherentFaults;
    }
    assert(0);
    return 0;
  }

  /// Defined out of line. Derived constructors call it after changing numDQ.
  void update_pinpos(bool group = false, int group_size = 4);

  /// Sub-faults (shared pointers). When non-empty, overlap() is decided by
  /// these instead of by this object's own addr/mask.
  std::vector<std::shared_ptr<Fault>> detailed_faults;

  double getCellFaultRate() const { return cellFaultRate; }

  /// Writes "<name> ADDR=... MASK=... (T=<isTransient>)" to @p fd.
  void print(FILE *fd = stdout) const {
    fprintf(fd, "%s ADDR=%016llx MASK=%016llx (T=%d)\n", name.c_str(), addr,
            getMask(), isTransient);
  }

  /**
   * @brief Tests whether this fault and @p other can hit the same address.
   * @param other fault to compare against; a null pointer is reported as
   *              overlapping.
   * @return true when the two faults overlap.
   *
   * Based on "FaultSim: A Fast, Configurable Memory-Reliability Simulator for
   * Conventional and 3D-Stacked Systems".
   */
  bool overlap(Fault *other) {
    if (other == nullptr) return true;

    // A fault carrying detailed sub-faults overlaps iff any sub-fault does.
    if (!detailed_faults.empty()) {
      for (auto &f : detailed_faults) {
        if (f->overlap(other)) return true;
      }
      return false;
    }
    if (!other->detailed_faults.empty()) {
      for (auto &f : other->detailed_faults) {
        if (f->overlap(this)) return true;
      }
      return false;
    }

    // Leaf case: every address bit must either be covered by one of the two
    // effective masks or be equal in both addresses. The low three bits are
    // always ignored (7 is for 8 bytes).
    const ADDR combinedMask = getEffectiveMask() | other->getEffectiveMask();
    const ADDR combinedAddr = ~(addr ^ other->addr);
    const ADDR total = combinedMask | combinedAddr | 7;
    return total == 0xFFFFFFFFFFFFFFFFull;
  }

  /**
   * @brief Hook for fault types that can narrow their mask.
   * @param pos type-specific arguments; ignored by this default.
   * @return empty vector; the default just copies mask into effective_mask.
   */
  virtual std::vector<int> setFinegrainedMask(std::vector<int> &pos) {
    effective_mask = mask;
    return std::vector<int>();
  }

  /**
   * @brief Flips random bits of @p line on this fault's pins and beats.
   *
   * The bank list is shuffled, then for each of the first numBank_errors
   * banks every beat belonging to that bank gets a random numDQ-bit pattern;
   * set bits invert the corresponding pin. The whole pass is repeated until
   * at least one bit has been inverted. Finally line->errorDQ is increased by
   * numDQ.
   */
  virtual void genRandomError(CacheLine *line) {
    // Guard the shift: 1ULL << 64 is undefined, so a 64-pin fault uses the
    // all-ones pattern directly.
    const unsigned long long maxPattern =
        (numDQ >= 64) ? ULLONG_MAX : ((1ULL << numDQ) - 1);
    std::uniform_int_distribution<unsigned long long> randDist(0, maxPattern);

    // beatCount should be a multiple of the number of banks per beat.
    assert(beatCount % numBanks_perBeat == 0);
    const int beatCountperBank = beatCount / numBanks_perBeat;

    // std::random_shuffle was removed in C++17; std::shuffle with the shared
    // engine is the replacement.
    std::shuffle(bank_list, bank_list + numBanks_perBeat, randomGenerator);

    bool noError = true;
    while (noError) {
      for (int idx = 0; idx < numBank_errors; idx++) {
        const int bankidx = bank_list[idx];
        const int firstBeat = beatCountperBank * bankidx;
        const int lastBeat = beatCountperBank * (bankidx + 1);
        for (int beat = firstBeat; beat < lastBeat; beat++) {
          const unsigned long long randValue = randDist(randomGenerator);
          for (int pin = 0; pin < numDQ; pin++) {
            if ((randValue >> pin) & 1) {
              // GONG: a bit that is already flipped is simply flipped again;
              // a variant that forbids bit-error overlap (e.g. for scenario
              // tests of multi bit errors) would have to pick a new position
              // here instead.
              line->invBit(line->getChannelWidth() * (beat + beatStart) +
                           pinPos[pin]);
              noError = false;
            }
          }
        }
      }
    }
    line->errorDQ += numDQ;
  }

  /// Multi-error injection hook; a no-op unless a derived class overrides it.
  virtual void genRandomErrors(CacheLine *line, int numErrors, bool chipRand) {}

  // static
  static Fault *genRandomFault(std::string type,
                               FaultDomain *fd);  //!< random fault generation
  bool overlapped = false;  //!< overlap with inherent faults

 public:
  FaultDomain *fd;
  std::string name;
  ADDR addr;            //!< address used for fault overlap checking
  ADDR mask;            //!< mask used for fault overlap checking
  ADDR effective_mask;  //!< narrowed mask; getEffectiveMask() falls back to mask while this is 0
  bool isInherent;      //!< inherent or operational
  bool isTransient;     //!< transient or permanent
  int numDQ;            //!< number of DQs involved
  bool isSingleBeat;    //!< single or multiple beats
  bool isMultiColumn;   //!< multiple or single column
  bool isMultiRow;      //!< multiple or single row
  bool isChannel;       //!< channel related faults
  bool isMultipleBanks_perBeat;  //!< factor to represent Bank level ECC
  int numBanks_perBeat;  //!< number of banks per beat (burstLength)

  unsigned long long affectedBlkCount;
  int beatStart, beatEnd, beatCount;
  int chipPos;
  int pinPos[128];
  double cellFaultRate;
  int numInherentFaults;
  int *bank_list;      //!< shuffled by genRandomError(); holds numBanks_perBeat entries
  int numBank_errors;  //!< how many entries of bank_list receive errors

  // Named values for the boolean constructor arguments.
  static const bool INHERENT = true;
  static const bool OPERATIONAL = false;
  static const bool TRANSIENT = true;
  static const bool PERMANENT = false;
  static const bool SINGLE_BEAT = true;
  static const bool MULTI_BEAT = false;
  static const bool MULTI_COLUMN = true;
  static const bool SINGLE_COLUMN = false;
  static const bool MULTI_ROW = true;
  static const bool SINGLE_ROW = false;
  static const bool CHANNEL = true;
  static const bool NO_CHANNEL = false;
};
/* @} */

/**
 * @brief single bit fault class
 */
class SingleBitFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   *
   * Fixed attributes: name "Sbit", SBIT_MASK, operational, one DQ, single
   * beat, single row, single column, not a channel fault, affected-block
   * count 0 and one bank per beat.
   */
  SingleBitFault(FaultDomain *fd, bool _isTransient)
      : Fault(fd, "Sbit", SBIT_MASK, OPERATIONAL, _isTransient,
              /*_numDQ=*/1, SINGLE_BEAT, SINGLE_ROW, SINGLE_COLUMN, NO_CHANNEL,
              /*_affectedBlkCount=*/0, /*_banksperBeat=*/1) {}
};

/**
 * @brief	single word fault class
 */
class SingleWordFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _numDQ       number of DQs involved (no lower bound is enforced
   *                     here; a value of 1 is accepted)
   *
   * Fixed attributes: name "Sword", SWORD_MASK, operational, single beat,
   * single row, single column, not a channel fault, affected-block count 0
   * and one bank per beat.
   */
  SingleWordFault(FaultDomain *fd, bool _isTransient, int _numDQ)
      : Fault(fd, "Sword", SWORD_MASK, OPERATIONAL, _isTransient, _numDQ,
              SINGLE_BEAT, SINGLE_ROW, SINGLE_COLUMN, NO_CHANNEL,
              /*_affectedBlkCount=*/0, /*_banksperBeat=*/1) {}
};

/**
 * @brief single pin fault class
 */
class SinglePinFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   *
   * Fixed attributes: name "Spin", CHANNEL_MASK, operational, one DQ,
   * multi-beat, multi-row, multi-column, not a channel fault. The
   * affected-block count is derived from MRANK_MASK and the banks-per-beat
   * value is the global BANKSPERBEAT.
   */
  SinglePinFault(FaultDomain *fd, bool _isTransient)
      : Fault(fd, "Spin", CHANNEL_MASK, OPERATIONAL, _isTransient,
              /*_numDQ=*/1, MULTI_BEAT, MULTI_ROW, MULTI_COLUMN, NO_CHANNEL,
              /*_affectedBlkCount=*/((MRANK_MASK ^ DEFAULT_MASK) + 1) / 8,
              BANKSPERBEAT) {}
};

/**
 * @brief	single chip fault class
 */
class SingleChipFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _numDQ       number of DQs involved
   *
   * Fixed attributes: name "Schip", CHANNEL_MASK, operational, multi-beat,
   * multi-row, multi-column, not a channel fault. The affected-block count is
   * derived from MBANK_MASK and the banks-per-beat value is the global
   * BANKSPERBEAT.
   */
  SingleChipFault(FaultDomain *fd, bool _isTransient, int _numDQ)
      : Fault(fd, "Schip", CHANNEL_MASK, OPERATIONAL, _isTransient, _numDQ,
              MULTI_BEAT, MULTI_ROW, MULTI_COLUMN, NO_CHANNEL,
              /*_affectedBlkCount=*/((MBANK_MASK ^ DEFAULT_MASK) + 1) / 8,
              BANKSPERBEAT) {}
};

/**
 * @brief	channel fault class
 */
class ChannelFault : public Fault {
 public:
  ChannelFault(FaultDomain *fd, bool _isTransient)
      : Fault(fd, "Channel", CHANNEL_MASK, OPERATIONAL, _isTransient,
              fd->getChipWidth(), MULTI_BEAT, MULTI_COLUMN, MULTI_ROW, CHANNEL,
              ((MRANK_MASK ^ DEFAULT_MASK) + 1) / 8, BANKSPERBEAT) {}
  void genRandomError(CacheLine *line) {
    bool noError = true;
    while (noError) {
      std::uniform_int_distribution<unsigned long long> randDist =
          std::uniform_int_distribution<unsigned long long>(0, ULLONG_MAX);
      unsigned long long randValue = randDist(randomGenerator);
      int offset = 0;
      for (int i = line->getBitN() - 1; i >= 0; i--) {
        if ((randValue >> offset) & 1) {
          line->invBit(i);
          noError = false;
        }
        offset++;
        if (offset == 64) {
          randValue = randDist(randomGenerator);
          offset = 0;
        }
      }
    }
  }
};

/**
 * @brief	single column fault class
 */
class SingleColumnFault : public Fault {
 public:
  SingleColumnFault(FaultDomain *fd, bool _isTransient, int _numDQ)
      : Fault(fd, "Scol", SCOL_MASK, OPERATIONAL, _isTransient, _numDQ,
              SINGLE_BEAT, MULTI_ROW, SINGLE_COLUMN, NO_CHANNEL, 0, 1) {
    double p = ((double)rand()) / RAND_MAX;
  }
  std::vector<int> setFinegrainedMask(std::vector<int> &pos) {
    /*
      pos[0] = _burstRowlength, pos[1] = _num_groups
      if pos[2],[3] exist, it will specify the rowbit and groupbit position
      burstRow_length means the number of rows that suffer error in a burst
      num_groups will generate that amount of burst errors
      1024, 2 will generate 2048 rows of error
      column_address_bits = 10  
    */    
    int _burstRow_length = pos[0];
    int _num_groups = pos[1];
    effective_mask = DEFAULT_MASK;
    int num_row_bit = ceil(log2(_burstRow_length));
    int num_group_bit = ceil(log2(_num_groups));
    if (num_row_bit + num_group_bit > row_address_bits) {
      std::cout << "Error: too many rows or groups" << std::endl;
      exit(1);
    }
    int random_row_bitpos = 0;
    int random_group_bitpos = 0;
    if(pos.size() == 2){
      random_row_bitpos = rand() % (row_address_bits - num_row_bit - num_group_bit);
      random_group_bitpos = rand() % (row_address_bits - num_row_bit - num_group_bit - random_row_bitpos);
    } else {
      assert(pos.size() == 4);
      random_row_bitpos = pos[2];
      random_group_bitpos = pos[3];
    }
    effective_mask = effective_mask | ((1 << num_row_bit) - 1) << (column_address_bits + random_row_bitpos);
    effective_mask = effective_mask | ((1 << num_group_bit) - 1) << (column_address_bits + num_row_bit + random_group_bitpos);
    return std::vector<int>{random_row_bitpos, random_group_bitpos};
  }
};

/**
 * @brief	single row fault class
 */
class SingleRowFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _numDQ       number of DQs involved; when it is 1 the mask is
   *                     replaced by SROW_MASK | combo_mask
   *
   * Fixed attributes: name "Srow", SROW_MASK, operational, multi-beat, single
   * row, multi-column, not a channel fault. The affected-block count passed
   * to the base constructor is overwritten with 1 in the body.
   */
  SingleRowFault(FaultDomain *fd, bool _isTransient, int _numDQ)
      : Fault(fd, "Srow", SROW_MASK, OPERATIONAL, _isTransient, _numDQ,
              MULTI_BEAT, SINGLE_ROW, MULTI_COLUMN, NO_CHANNEL,
              (((SROW_MASK ^ DEFAULT_MASK) >> 15) + 1) / 8, BANKSPERBEAT) {
    // Result is unused; the call still advances rand()'s state as before.
    (void)rand();

    const bool singleDQ = (_numDQ == 1);
    if (singleDQ) {
      mask = SROW_MASK | combo_mask;
    }
    affectedBlkCount = 1;
  }
};

/**
 * @brief Fault named "Lwordline"; depending on @c special and on DRAMTYPE the
 *        constructor may rename it to "RDEC" or "SWD" and change its mask and
 *        DQ count.
 */
class LocalWordlineFault : public Fault {
 public:
  /**
   * @param fd           fault domain; also supplies the chip width
   * @param _isTransient forwarded as the transient/permanent flag
   * @param special      1: with probability 0.027 widen to a full-chip "RDEC"
   *                     fault, otherwise keep a narrow fault;
   *                     2: become an "SWD" fault with mask
   *                     SROW_MASK | combo_mask;
   *                     any other value leaves the base attributes untouched.
   */
  LocalWordlineFault(FaultDomain *fd, bool _isTransient, int special)
      : Fault(fd, "Lwordline", SROW_MASK, OPERATIONAL, _isTransient, 1,
              MULTI_BEAT, SINGLE_ROW, MULTI_COLUMN, NO_CHANNEL,
              (((SROW_MASK ^ DEFAULT_MASK) >> 15) + 1) / 8, BANKSPERBEAT) {
    // Drawn unconditionally, before any branch.
    const double draw = static_cast<double>(rand()) / RAND_MAX;
    const bool isHbm3 = (strcmp(DRAMTYPE, "HBM3") == 0);

    switch (special) {
      case 1:
        if (draw < 0.027) {  // DUE rate of local_wordline
          numDQ = fd->getChipWidth();
          this->name = "RDEC";
          mask = SWD_MASK;
          if (isHbm3) {
            update_pinpos(true, 4);
          } else {
            update_pinpos(true, 2);
          }
        } else if (isHbm3) {
          numDQ = 2;
          update_pinpos(true, 2);
        } else {
          numDQ = 1;
          update_pinpos(true);
        }
        break;

      case 2:
        // TODO: generalize this
        mask = SROW_MASK | combo_mask;
        this->name = "SWD";
        if (isHbm3) {
          numDQ = 4;
          update_pinpos(true, 4);
        } else if (strcmp(DRAMTYPE, "DDR5") == 0 ||
                   strcmp(DRAMTYPE, "LPDDR5") == 0) {
          numDQ = 2;
          update_pinpos(true, 2);
        } else {
          numDQ = 1;
          update_pinpos();
        }
        break;

      default:
        break;
    }

    affectedBlkCount = 1;
  }
};

/**
 * @brief Fault named "BLSA" using BLSA_MASK.
 */
class BLSAFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   *
   * Fixed attributes: operational, one DQ, single beat, multi-row, single
   * column, not a channel fault. The affected-block count handed to the base
   * constructor is replaced by 2048 in the body, after which the pin
   * positions are refreshed.
   */
  BLSAFault(FaultDomain *fd, bool _isTransient)
      : Fault(fd, "BLSA", BLSA_MASK, OPERATIONAL, _isTransient, 1,
              SINGLE_BEAT, MULTI_ROW, SINGLE_COLUMN, NO_CHANNEL,
              DEFAULT_MASK, BANKSPERBEAT) {
    // Result is unused; the call still advances rand()'s state as before.
    (void)rand();

    affectedBlkCount = 2048;  // two 1k blocks
    update_pinpos();
  }
};

/**
 * @brief Fault named "Bank_pattern": BANK_PATTERN_MASK plus one random low
 *        mask bit and, depending on @c _special, extra bits above the column
 *        and row address bits.
 */
class BankPatternFault : public Fault {
 public:
  /**
   * @param fd           fault domain; also supplies the chip width used as
   *                     the initial number of DQs
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _special     0: OR in the value 8 at the bank field and use one DQ;
   *                     2: OR in (numofBanks - 1) at the bank field;
   *                     3: OR in a single bit 1..3 positions above the start
   *                        of the bank field;
   *                     any other value adds no bank bits.
   *
   * All shifts are carried out in ADDR width: the shift amounts are
   * column_address_bits + row_address_bits (+ up to 3), which can push an
   * int operand past bit 31 and then corrupt the upper mask bits.
   */
  BankPatternFault(FaultDomain *fd, bool _isTransient, int _special)
      : Fault(fd, "Bank_pattern", BANK_PATTERN_MASK, OPERATIONAL, _isTransient,
              fd->getChipWidth(), MULTI_BEAT, MULTI_ROW, MULTI_COLUMN,
              NO_CHANNEL, DEFAULT_MASK, BANKSPERBEAT) {
    // The value is not used; the draw is kept so the rand() sequence seen by
    // later code is unchanged.
    (void)rand();

    // choose number from 3-9
    int num = rand() % 7 + 3;
    mask = mask | (static_cast<ADDR>(1) << num);

    // First bit position above the column and row address bits.
    const int bank_shift = column_address_bits + row_address_bits;
    if (_special == 0) {
      // diff by 8 banks
      mask = mask | (static_cast<ADDR>(8) << bank_shift);
      numDQ = 1;
    } else if (_special == 2) {
      // many banks
      mask = mask | (static_cast<ADDR>(numofBanks - 1) << bank_shift);
    } else if (_special == 3) {
      // not diff by 8 banks, but two bank
      num = rand() % 3 + 1;
      mask = mask | (static_cast<ADDR>(1) << (bank_shift + num));
    }

    affectedBlkCount = 16 * 1024 * 2;  // 16k block(subbank) * 2 banks
    update_pinpos();
  }
};

/**
 * @brief Fault named "CDEC" using CDEC_MASK; the constructor may turn it into
 *        a one-DQ "CSL" fault.
 */
class CDECFault : public Fault {
 public:
  /**
   * @param fd           fault domain; also supplies the chip width
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _special     0: decoder_multi_col - adds one random mask bit
   *                        (bit 3..9) and becomes "CSL" when the draw exceeds
   *                        0.0848;
   *                     1: becomes "CSL" when the draw exceeds 0.395;
   *                     any other value keeps the base attributes.
   *
   * When DRAMTYPE is "LPDDR5_SPLIT" the DQ count is finally forced to 8.
   */
  CDECFault(FaultDomain *fd, bool _isTransient, int _special)
      : Fault(fd, "CDEC", CDEC_MASK, OPERATIONAL, _isTransient,
              fd->getChipWidth(), MULTI_BEAT, MULTI_ROW, MULTI_COLUMN,
              NO_CHANNEL, DEFAULT_MASK, BANKSPERBEAT) {
    // Drawn unconditionally, before any branch.
    const double draw = static_cast<double>(rand()) / RAND_MAX;

    const bool multiCol = (_special == 0);
    const bool singleBank = (_special == 1);
    if (multiCol || singleBank) {
      if (multiCol) {
        // decoder_multi_col: choose number from 3-9
        const int bitIndex = rand() % 7 + 3;
        mask = mask | (1 << bitIndex);
      }

      // col_single_bank DUE prob is 0.0848 (special 0) / 0.395 (special 1)
      // ==> 4DQ happen for that fraction, everything above becomes CSL.
      bool becomesCsl;
      if (multiCol) {
        becomesCsl = draw > 0.0848;
      } else {
        becomesCsl = draw > 0.395;
      }

      if (becomesCsl) {
        this->name = "CSL";
        mask = CSL_MASK;
        numDQ = 1;
        update_pinpos();
      } else {
        numDQ = fd->getChipWidth();
      }
    }

    affectedBlkCount = 16 * 1024 * 2;

    if (strcmp(DRAMTYPE, "LPDDR5_SPLIT") == 0) {
      numDQ = 8;
      update_pinpos(true, 8);
    }
  }
};

/**
 * @brief Fault named "CSL" using CSL_MASK, optionally widened by @c _special.
 */
class CSLFault : public Fault {
 public:
  /**
   * @param fd           fault domain; also supplies the chip width
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _special     0: multi csl - all column address bits are added to
   *                        the mask and the DQ count becomes the chip width;
   *                     1: csl_single_bank - one random column bit (bit 3 or
   *                        higher) is added to the mask;
   *                     2 or anything else: the mask is left as CSL_MASK.
   */
  CSLFault(FaultDomain *fd, bool _isTransient, int _special)
      : Fault(fd, "CSL", CSL_MASK, OPERATIONAL, _isTransient, 1,
              SINGLE_BEAT, MULTI_ROW, MULTI_COLUMN, NO_CHANNEL,
              DEFAULT_MASK, BANKSPERBEAT) {
    // Result is unused; the call still advances rand()'s state as before.
    (void)rand();

    switch (_special) {
      case 0:
        // It is multi csl, column and bank
        // would not exceed 16k rows, but no guarantee for others
        mask = mask |
               (0xFFFFFFFFFFFFFFFFULL >> (64 - (int)(column_address_bits)));
        numDQ = fd->getChipWidth();
        break;

      case 1: {
        // It is csl_single_bank. Select 1 column
        // which is two column error.
        const int colBit = rand() % (column_address_bits - 3);
        mask = mask | (0x1 << (colBit + 3));
        break;
      }

      default:
        // _special == 2 (and every other value): mask stays as it is.
        break;
    }

    affectedBlkCount = 16 * 1024;  // 16k block(subbank)
    update_pinpos();
  }
};

/**
 * @brief Composite fault named "Multi_module": it carries either one
 *        BankPatternFault or 1..32 SingleRowFault objects in detailed_faults.
 */
class MultiModuleFault : public Fault {
 public:
  /**
   * @param fd           fault domain; also supplies the chip width
   * @param _isTransient forwarded to this fault and to every sub-fault
   * @param _special     accepted but not read by this constructor
   *
   * With probability one half a single bank-pattern sub-fault is attached;
   * otherwise a random number (1..32) of row sub-faults is attached, each
   * keeping this fault's address bits outside SBANK_MASK and taking random
   * bits inside it.
   */
  MultiModuleFault(FaultDomain *fd, bool _isTransient, int _special)
      : Fault(fd, "Multi_module", CHANNEL_MASK, OPERATIONAL, _isTransient,
              fd->getChipWidth(), MULTI_BEAT, MULTI_ROW, MULTI_COLUMN,
              NO_CHANNEL, DEFAULT_MASK, BANKSPERBEAT) {
    // Add at most 32 row faults or
    // at most 2 column faults
    this->numDQ = fd->getChipWidth();
    const double coin = static_cast<double>(rand()) / RAND_MAX;

    if (coin < 0.5) {
      auto pattern = std::make_shared<BankPatternFault>(fd, _isTransient, 0);
      this->detailed_faults.push_back(pattern);
      affectedBlkCount = 16 * 1024 * 2;
    } else {
      const int rowCount = rand() % 32 + 1;
      int made = 0;
      while (made < rowCount) {
        auto rowFault = std::make_shared<SingleRowFault>(fd, _isTransient,
                                                         fd->getChipWidth());
        // only guarantee the same bank
        rowFault->addr =
            (this->addr & (~SBANK_MASK)) |
            ((RAND_MAX * ((ADDR)rand()) + rand()) & (SBANK_MASK));
        this->detailed_faults.push_back(rowFault);
        ++made;
      }
      affectedBlkCount = rowCount;
    }

    update_pinpos();
  }
};

/**
 * @brief Fault named "RDEC" using RDEC_MASK; the constructor may rename it to
 *        "SWD" and narrow its DQ count.
 */
class RDECFault : public Fault {
 public:
  /**
   * @param fd           fault domain; also supplies the chip width
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _special     0: lwl_sel / lwl_sel2 - mask becomes LWL_MASK plus a
   *                        random value 1..63 placed above the column and
   *                        subarray address bits;
   *                     1: mask becomes RDEC_MASK plus a random value 0..63
   *                        placed at bit 16;
   *                     any other value keeps the base attributes.
   */
  RDECFault(FaultDomain *fd, bool _isTransient, int _special)
      : Fault(fd, "RDEC", RDEC_MASK, OPERATIONAL, _isTransient,
              fd->getChipWidth(), MULTI_BEAT, MULTI_ROW, MULTI_COLUMN,
              NO_CHANNEL, DEFAULT_MASK, BANKSPERBEAT) {
    // Drawn unconditionally, before any branch.
    const double draw = static_cast<double>(rand()) / RAND_MAX;

    if (_special == 0) {
      // lwl_sel, lwl_sel2
      const int rowsHit = rand() % 63 + 1;
      mask = LWL_MASK |
             (rowsHit << (column_address_bits + subarray_address_bits));

      // TODO: generalize this
      // When we assume there are 1 MAT / 1DQ:
      // 3/5 change 2 DQ, 2/5 chance 1 DQ
      const bool isHbm3 = (strcmp(DRAMTYPE, "HBM3") == 0);
      if (draw < 0.031) {  // DUE rate of lwl_sel is 0.031
        numDQ = fd->getChipWidth();
      } else {
        this->name = "SWD";
        if (isHbm3) {
          numDQ = 4;
        } else {
          numDQ = 1;
        }
      }
      update_pinpos(true, 4);
      affectedBlkCount = rowsHit;
      return;
    }

    if (_special == 1) {
      // On RDEC_MASK, add 1 to randomly on 17th to 22nd bit
      const int pattern = rand() % 64;
      mask = RDEC_MASK | (pattern << (16));
      affectedBlkCount = pattern;
      if (draw < 0.18) {  // row_decoder DUE rate
        numDQ = 2;
      } else {
        this->name = "SWD";
        numDQ = 1;
      }
      update_pinpos(true, 4);
    }
  }
};

/**
 * @brief Fault named "SWD" using SWD_MASK; the constructor may rename it to
 *        "RDEC" and changes the DQ count per DRAMTYPE.
 */
class SWDFault : public Fault {
 public:
  /**
   * @param fd           fault domain; also supplies the chip width
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _special     0: one cluster - with probability 0.19 the fault
   *                        becomes "RDEC" with RDEC_MASK;
   *                     1: mask becomes SWD_MASK | combo_mask;
   *                     any other value keeps the base attributes.
   *
   * The affected-block count is always set to 2.
   */
  SWDFault(FaultDomain *fd, bool _isTransient, int _special)
      : Fault(fd, "SWD", SWD_MASK, OPERATIONAL, _isTransient, 2,
              MULTI_BEAT, MULTI_ROW, MULTI_COLUMN, NO_CHANNEL,
              DEFAULT_MASK, BANKSPERBEAT) {
    const bool isHbm3 = (strcmp(DRAMTYPE, "HBM3") == 0);

    switch (_special) {
      case 0: {
        // one cluster
        mask = SWD_MASK;
        const double draw = static_cast<double>(rand()) / RAND_MAX;
        // TODO: generalize this

        // When we assume there are 2 MAT / 1DQ:
        // 3/9 change 2 DQ, 6/9 chance 1 DQ
        if (draw < 0.19) {
          if (isHbm3) {
            numDQ = fd->getChipWidth();
          } else {
            numDQ = 4;
          }
          this->name = "RDEC";
          mask = RDEC_MASK;
        } else if (isHbm3) {
          numDQ = 4;
        } else {
          numDQ = 1;
        }
        // Grouped pins with the default group size of 4 in both DRAM cases.
        update_pinpos(true, 4);
        break;
      }

      case 1:
        mask = SWD_MASK | combo_mask;
        if (isHbm3) {
          numDQ = 4;
        } else if (combo_mask == 0) {
          numDQ = 2;
        } else {
          numDQ = 1;
        }
        update_pinpos(true, 4);
        break;

      default:
        break;
    }

    affectedBlkCount = 2;
  }
};

/**
 * @brief Composite fault named "Dist_bit" built from single-word sub-faults.
 */
class DistBitFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to this fault and sub-faults
   * @param _isTransient forwarded to this fault and sub-faults
   * @param _special     0 attaches ten SingleWordFault objects (4 DQs each)
   *                     to detailed_faults; any other value attaches none.
   */
  DistBitFault(FaultDomain *fd, bool _isTransient, int _special)
      : Fault(fd, "Dist_bit", DEFAULT_MASK, OPERATIONAL, _isTransient, 1,
              SINGLE_BEAT, SINGLE_ROW, SINGLE_COLUMN, NO_CHANNEL,
              DEFAULT_MASK, BANKSPERBEAT) {
    if (_special != 0) {
      return;
    }
    int remaining = 10;
    while (remaining-- > 0) {
      auto word = std::make_shared<SingleWordFault>(fd, _isTransient, 4);
      this->detailed_faults.push_back(word);
    }
  }
};



/**
 * @brief	single bank fault class
 */
class SingleBankFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _numDQ       number of DQs involved
   *
   * Fixed attributes: name "Sbank", SBANK_MASK, operational, multi-beat,
   * multi-row, multi-column, not a channel fault, affected-block count 0.
   */
  SingleBankFault(FaultDomain *fd, bool _isTransient, int _numDQ)
      : Fault(fd, "Sbank", SBANK_MASK, OPERATIONAL, _isTransient, _numDQ,
              MULTI_BEAT, MULTI_ROW, MULTI_COLUMN, NO_CHANNEL,
              /*_affectedBlkCount=*/0, BANKSPERBEAT) {
    // Result is unused; the call still advances rand()'s state as before.
    (void)rand();
  }
};

/**
 * @brief	Multi-bank fault class
 */
class MultiBankFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _numDQ       number of DQs involved
   *
   * Fixed attributes: name "Mbank", MBANK_MASK, operational, multi-beat,
   * multi-row, multi-column, not a channel fault, affected-block count 0.
   */
  MultiBankFault(FaultDomain *fd, bool _isTransient, int _numDQ)
      : Fault(fd, "Mbank", MBANK_MASK, OPERATIONAL, _isTransient, _numDQ,
              MULTI_BEAT, MULTI_ROW, MULTI_COLUMN, NO_CHANNEL,
              /*_affectedBlkCount=*/0, BANKSPERBEAT) {
    // Result is unused; the call still advances rand()'s state as before.
    (void)rand();
  }
};

/**
 * @brief	Multi-rank fault class
 */
class MultiRankFault : public Fault {
 public:
  /**
   * @param fd           fault domain forwarded to the Fault constructor
   * @param _isTransient forwarded as the transient/permanent flag
   * @param _numDQ       number of DQs involved
   *
   * Fixed attributes: name "Mrank", MRANK_MASK, operational, multi-beat,
   * multi-row, multi-column, not a channel fault, affected-block count 0.
   */
  MultiRankFault(FaultDomain *fd, bool _isTransient, int _numDQ)
      : Fault(fd, "Mrank", MRANK_MASK, OPERATIONAL, _isTransient, _numDQ,
              MULTI_BEAT, MULTI_ROW, MULTI_COLUMN, NO_CHANNEL,
              /*_affectedBlkCount=*/0, BANKSPERBEAT) {
    // Result is unused; the call still advances rand()'s state as before.
    (void)rand();
  }
};

/**
 * @brief	inherent fault class
 */
class InherentFault : public Fault {
 public:
  InherentFault(FaultDomain *fd, double _cellFaultRate,
                double _newWeakCellRate = 0.0)
      : Fault("Inherent") {
    cellFaultRate = _cellFaultRate;
    newWeakCellRate = _newWeakCellRate;
  }

  void genRandomErrors(CacheLine *line, int numErrors, bool chipRand) {
    int ChannelWidth = line->getChannelWidth();
    int chipWidth = line->getChipWidth();
    int height = line->getBeatHeight();
    int chip = rand() % (ChannelWidth / chipWidth);
    numInherentFaults = numErrors;
    if (numErrors <= Twelv) {
      for (int i = 0; i < numErrors; i++) {
        if (chipRand) chip = rand() % (ChannelWidth / chipWidth);
        int bitPos = rand() % chipWidth;
        int beatPos = rand() % height;
        int bit = ChannelWidth * beatPos + chip * chipWidth + bitPos;
        if (line->bitArr[bit] == 0) {
          line->setBit(bit, true);
        } else {
          // conflict with the previous fault
          // do it again
          i--;
        }
      }
    } else if (numErrors <= DoubleSingleSingle18) {
      int firstChip, secondChip;
      int iter;
      numInherentFaults=0;
      if (numErrors == DoubleSingleSingle18)
        iter = Triple;
      else
        iter = Double;
      for (int j = 0; j < iter; j++) {
        if (j == 0) {
          firstChip = chip = rand() % (ChannelWidth / chipWidth);
        } else if (j == 1) {
          while (firstChip == chip) {
            chip = rand() % (ChannelWidth / chipWidth);
          }
          secondChip = chip;
        } else {
          while (firstChip == chip || secondChip == chip) {
            chip = rand() % (ChannelWidth / chipWidth);
          }
        }
        int err = Double;
        if (numErrors == SingleSingle18 || numErrors == SingleSingle10)
          err = Single;
        else if (numErrors == DoubleDouble18 || numErrors == DoubleDouble9 ||
                 numErrors == DoubleDouble10)
          err = Double;
        else if (j > 0)
          err = Single;
        numInherentFaults += err;
        for (int i = 0; i < err; i++) {
          int bitPos = rand() % chipWidth;
          int beatPos = rand() % height;
          int bit = ChannelWidth * beatPos + chip * chipWidth + bitPos;
          if (line->bitArr[bit] == 0) {
            line->setBit(bit, true);
          } else {
            // conflict with the previous fault
            // do it again
            i--;
          }
        }
      }
    } else if (numErrors == SingleSingleSingleOn18Symbol){
      numInherentFaults=3;
      int chip = rand() % (ChannelWidth / chipWidth);
      int bitPos = rand() % chipWidth;
      int prevbeat = -1;
      for (int j = 0; j < 3; j++) {
        int beatPos = rand() % height;
        int bit = ChannelWidth * beatPos + chip * chipWidth + bitPos;
        if (prevbeat == -1) {
          prevbeat = beatPos;
        }else{
          if(prevbeat/16 == beatPos/16){
            //symbol size is 16bit...
            j--;
            continue;
          }
        }
        if (line->bitArr[bit] == 0) {
            line->setBit(bit, true);
          } else {
            // conflict with the previous fault
            // do it again
            j--;
          }
      }
    } else if (numErrors == SingleSingleOn18Symbol){
      numInherentFaults=2;
      int chip = rand() % (ChannelWidth / chipWidth);
      int bitPos = rand() % chipWidth;
      int prevbeat = -1;
      for (int j = 0; j < 2; j++) {
        int beatPos = rand() % height;
        int bit = ChannelWidth * beatPos + chip * chipWidth + bitPos;
        if (prevbeat == -1) {
          prevbeat = beatPos;
        }else{
          if(prevbeat/16 == beatPos/16){
            //symbol size is 16bit...
            j--;
            continue;
          }
        }
        if (line->bitArr[bit] == 0) {
            line->setBit(bit, true);
          } else {
            // conflict with the previous fault
            // do it again
            j--;
          }
        }
      } else {
      // this should not happen
      exit(0);
    }
  }
  void genRandomError(CacheLine *line) {
    std::binomial_distribution<int> distribution(line->getBitN(),
                                                 cellFaultRate);
    int faultyCellCount = distribution(randomGenerator);
    // assert(line->isZero());
    numInherentFaults = faultyCellCount;
    int ChannelWidth = line->getChannelWidth();
    int chipWidth = line->getChipWidth();
    int height = line->getBeatHeight();
    int chip;
    std::list<int> chip_list;
    std::list<int>::iterator iter;
    for (int i = 0; i < faultyCellCount; i++) {
      do {
        chip = rand() % (ChannelWidth / chipWidth);
        iter = std::find(chip_list.begin(), chip_list.end(), chip);
      } while (iter != chip_list.end());

      int bitPos = rand() % chipWidth;
      int beatPos = rand() % height;
      int bit = ChannelWidth * beatPos + chip * chipWidth + bitPos;
      if (line->bitArr[bit] == 0) {
        line->bitArr[bit] = 1;
      } else {
        // conflict with the previous fault
        // do it again
        i--;
      }
    }
  }
  double getCFR() { return cellFaultRate; }
  double getElapsedTime() { return elapsedTime; }

  double elapsedTime;
  double newWeakCellRate;
};

/**
 * @brief old dummy inherent fault class (for comparison and debugging)
 */
class InherentFault2 : public Fault {
 public:
  InherentFault2(FaultDomain *fd, double _cellFaultRate, int _maxFault)
      : Fault("Inherent2") {
    cellFaultRate = _cellFaultRate;
    maxFault = _maxFault;
  }

  void genRandomError(CacheLine *line) {
    std::binomial_distribution<int> distribution(line->getBitN(),
                                                 cellFaultRate);
    int faultyCellCount = distribution(randomGenerator);
    assert(line->isZero());
    int limit;
    if (faultyCellCount > maxFault) {
      limit = maxFault;
    } else {
      limit = faultyCellCount;
    }
    for (int i = 0; i < limit; i++) {
      int bitPos = rand() % line->getBitN();
      if (line->bitArr[bitPos] == 0) {
        line->bitArr[bitPos] = 1;
      } else {
        // conflict with the previous fault
        // do it again
        i--;
      }
    }
  }

 protected:
  int maxFault;
};

#endif /* __FAULT_HH__ */
